/**
 * WARNING:
 *
 *   Any code here may be moved to the h2o-droplets repository in the future!
 */

package main.java.droplets;

import water.H2O;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.Vec;

import java.io.File;
import java.text.DecimalFormat;
import java.util.Random;

/**
 * Simplified version of H2O k-means algorithm for better readability.
 */
public class KMeansDroplet {
  /**
   * Starts an H2O node named {@code h2o_test_cloud} and waits for a cloud of size 1
   * (a 10*1000 ms limit is passed to {@code H2O.waitForCloudSize}).
   */
  public static void initCloud() {
    // Setup cloud name
    String[] args = new String[] { "-name", "h2o_test_cloud"};
    // Build a cloud of 1
    H2O.main(args);
    H2O.waitForCloudSize(1, 10*1000 /* ms */);
  }

  /**
   * Demo driver: parses {@code smalldata/glm_test/gaussian.csv}, drops its first column, picks
   * k = 7 random rows as initial centers, runs 10 k-means iterations (printing the total squared
   * error after each), prints the final centers and exits the JVM.
   */
  public static void main(String[] args) throws Exception {
    initCloud();

    // Load and parse a file. Data is distributed to other nodes in a round-robin way
    File f = new File("smalldata/glm_test/gaussian.csv");
    NFSFileVec nfs = NFSFileVec.make(f);
    Frame frame = water.parser.ParseDataset.parse(Key.make(),nfs._key);

    // Optionally create a frame with fewer columns, e.g. skip first.
    // NOTE(review): the Vec returned by remove(0) is discarded without being deleted; it
    // presumably stays in the DKV until System.exit below — confirm if this code is reused.
    frame.remove(0);

    // Create k centers, each a copy of a randomly chosen row
    final int k = 7;
    double[][] centers = initCenters(frame, k, new Random());

    // Iterate over the dataset and show error for each step
    final int numIters = 10;
    for( int i = 0; i < numIters; i++ ) {
      KMeans task = new KMeans();
      task._centers = centers;
      task.doAll(frame);
      updateCenters(centers, task);
      System.out.println("Error is " + task._error);
    }

    printCenters(centers);
    System.exit(0);
  }

  /**
   * Builds {@code k} initial centers, each a copy of an independently, uniformly chosen row of
   * {@code frame}. Two centers may start on the same row.
   *
   * @throws IllegalArgumentException if {@code frame} has no rows
   */
  private static double[][] initCenters(Frame frame, int k, Random rand) {
    Vec[] vecs = frame.vecs();
    long nrows = frame.numRows();
    if( nrows <= 0 )
      throw new IllegalArgumentException("Cannot pick k-means centers from a frame with " + nrows + " rows");
    double[][] centers = new double[k][vecs.length];
    for( double[] center : centers ) {
      // Scale by the ROW count so every row is a candidate; clamp guards against FP rounding.
      long row = Math.min(nrows - 1, (long) (rand.nextDouble() * nrows));
      for( int col = 0; col < vecs.length; col++ )
        center[col] = vecs[col].at(row);
    }
    return centers;
  }

  /**
   * Moves every non-empty cluster's center to the mean of the rows assigned to it by {@code task};
   * a cluster that received no rows keeps its previous center.
   */
  private static void updateCenters(double[][] centers, KMeans task) {
    for( int c = 0; c < centers.length; c++ ) {
      if( task._size[c] == 0 ) continue;
      for( int col = 0; col < centers[c].length; col++ )
        centers[c][col] = task._sums[c][col] / task._size[c];
    }
  }

  /** Prints each center on one line, values formatted as {@code #.00} and comma-separated. */
  private static void printCenters(double[][] centers) {
    System.out.println("Cluster Centers:");
    DecimalFormat df = new DecimalFormat("#.00");
    for( double[] center : centers ) {
      for( double value : center )
        System.out.print(df.format(value) + ", ");
      System.out.println("");
    }
  }


  /**
   * One k-means (Lloyd) pass: assigns every row to its nearest center by squared Euclidean
   * distance and accumulates, per cluster, the feature sums and row count, plus the total
   * squared distance.
   *
   * <p>For more complex tasks like this one, it is useful to mark fields that are provided by the
   * caller (IN), and fields generated by the task (OUT). IN fields can then be set to null when the
   * task is done using them, so that they do not get serialized back to the caller.
   */
  public static class KMeans extends MRTask<KMeans> {
    double[][] _centers; // IN:  Centroids/cluster centers
    double[][] _sums;    // OUT: Sum of features in each cluster
    // NOTE(review): int counts can overflow past 2^31-1 rows per cluster on very large data.
    int[]  _size;        // OUT: Row counts in each cluster
    double _error;       // OUT: Total sqr distance (rows with no finite distance are excluded)

    @Override public void map(Chunk[] chunks) {
      _sums = new double[_centers.length][chunks.length];
      _size = new int[_centers.length];
      double[] point = new double[chunks.length]; // current row, read once per row

      // Find nearest cluster for each row
      for( int row = 0; row < chunks[0]._len; row++ ) {
        for( int column = 0; column < chunks.length; column++ )
          point[column] = chunks[column].at0(row);

        int nearest = -1;
        double minSqr = Double.POSITIVE_INFINITY;
        for( int cluster = 0; cluster < _centers.length; cluster++ ) {
          double sqr = 0;           // Sum of dimensional distances
          for( int column = 0; column < chunks.length; column++ ) {
            double delta = point[column] - _centers[cluster][column];
            sqr += delta * delta;
          }
          if( sqr < minSqr ) {
            nearest = cluster;
            minSqr = sqr;
          }
        }
        // A row containing NaN (missing value) has NaN distance to every center, so no cluster
        // wins; skip it rather than index _sums[-1] or pollute _error.
        if( nearest < 0 ) continue;
        _error += minSqr;

        // Add values and increment counter for chosen cluster
        for( int column = 0; column < chunks.length; column++ )
          _sums[nearest][column] += point[column];
        _size[nearest]++;
      }
      _centers = null; // IN field: not needed in the result sent back to the caller
    }

    @Override public void reduce(KMeans task) {
      for( int cluster = 0; cluster < _size.length; cluster++ ) {
        for( int column = 0; column < _sums[cluster].length; column++ )
          _sums[cluster][column] += task._sums[cluster][column];
        _size[cluster] += task._size[cluster];
      }
      _error += task._error;
    }
  }
}